#include "nanovoid_app.h"

/**
 * Construct the solver state for an _Nx by _Ny grid.
 *
 * The OneStep initializer receives vals_len, _Nx * _Ny * 6, _Nx * _Ny and the
 * LSH settings (_lshK, _lshL, _lsh_r) in that order; the exact meaning of each
 * OneStep argument is defined in the header (presumably stencil length, value
 * table length and per-channel item count -- confirm there).
 *
 * @param _Nx    grid extent along x
 * @param _Ny    grid extent along y
 * @param _p     model parameters, copied into p
 * @param _lsh_r forwarded to OneStep as the last LSH argument
 * @param _lshK  stored in lshK and forwarded to OneStep
 * @param _lshL  stored in lshL and forwarded to OneStep
 */
NanoVoidOneBack::NanoVoidOneBack(int _Nx, int _Ny, ParameterSet _p, valueType _lsh_r, uint _lshK, uint _lshL)
        : OneStep(vals_len, _Nx * _Ny * 6, _Nx * _Ny, _lshK, _lshL, _lsh_r),
          Nx(_Nx),
          Ny(_Ny),
          size(_Nx * _Ny),
          lshK(_lshK),
          lshL(_lshL),
          p(_p),
          dp(0.0) {
}


void NanoVoidOneBack::grab_vals(uint item, valueType *value_table, valueType *vals) {
    uint start_pos = 0;
    uint start_vals_pos = 0;
    uint i = 0, pg = 0;

    int c_x = item / Ny;
    int c_y = item % Ny;

    uint pd, root_item;
    int cc_x, cc_y;

    for (; i < lap_len_2nd; ++ i) {
        cc_x = c_x; cc_y = c_y;

        cc_x += dx[i];
        cc_y += dy[i];

        cc_x = max(cc_x, 0);      // smch: may be changed to mod operation
        cc_x = min(cc_x, Nx-1);
        cc_y = max(cc_y, 0);
        cc_y = min(cc_y, Ny-1);

        pd = inv.item2pd[((uint)cc_x)*Ny + ((uint)cc_y)];
        root_item = inv.d_item(inv.find_(pd));

        start_pos = 0;
        start_vals_pos = 0;

        for (pg = 0; pg < n_channels; ++ pg) {
            vals[start_vals_pos + i] = value_table[root_item + start_pos];
            start_pos += size;
            start_vals_pos += lap_len_2nd;
        }
    }

// this is old version of using coordinate to grab vals
//    Coordinate2d3c c(0, 0);
//    c.from_item(item, size, size * size);
//
//    for (uint i = 0; i < lap_len_2nd; i++) {
//        Coordinate2d3c cc(c);
//
//        int x = cc.x + dx[i];
//        int y = cc.y + dy[i];
//
//        x = max(x, 0);
//        x = min(x, size - 1);
//        y = max(y, 0);
//        y = min(y, size - 1);
//
//        cc.x = x;
//        cc.y = y;
//
//        uint this_item_c1 = cc.to_item_c1(size, size * size);
//
//        // b here
//        uint pd1 = inv.item2pd[this_item_c1];
//        uint root1 = inv.find_(pd1);
//        uint root_item1 = inv.d_item(root1);
//        vals[i] = value_table[root_item1];
//
//        vals[i + lap_len_2nd] = value_table[root_item1 + num_items];
//
//        vals[i + lap_len_2nd * 2] = value_table[root_item1 + num_items * 2];
//
//        vals[i + lap_len_2nd * 3] = value_table[root_item1 + num_items * 3];
//
//        vals[i + lap_len_2nd * 4] = value_table[root_item1 + num_items * 4];
//
//        vals[i + lap_len_2nd * 5] = value_table[root_item1 + num_items * 5];
//    }
}


/**
 * Advance one cell: update its three field values and its three dloss values.
 *
 * Layout of vals as used below:
 *   - vals[0], vals[lap_len_2nd], vals[lap_len_2nd * 2] are the centre values
 *     of the three field stencils (named cv, ci, eta in this function);
 *   - the second half, starting at vals[vals_len / 2], holds the matching
 *     dloss stencils with the same lap_len_2nd spacing.
 *
 * Results are written to new_v at c + num_items * k:
 *   k = 0..2  the field values returned by forward_one_step_vals,
 *   k = 3..5  the dloss values computed here for cv, ci and eta.
 *
 * @param vals  gathered stencil buffer for cell c (see grab_vals)
 * @param c     item index of the cell being updated
 * @param new_v channel-planar output table
 */
void NanoVoidOneBack::forward_one_step(valueType *vals, uint c, valueType *new_v) {

    // std::cout << "in one back forward one step: " << c << std::endl;
    // Take |parameter| + 0.001 so every coefficient is strictly positive.
    // NOTE(review): r_bulk, r_surf, p_casc, bias and vg are computed but never
    // read in this function.
    valueType energy_v = std::abs(p.energy_v0) + 0.001;
    valueType energy_i = std::abs(p.energy_i0) + 0.001;
    valueType kBT = std::abs(p.kBT0) + 0.001;
    valueType kappa_v = std::abs(p.kappa_v0) + 0.001;
    valueType kappa_i = std::abs(p.kappa_i0) + 0.001;
    valueType kappa_eta = std::abs(p.kappa_eta0) + 0.001;
    valueType r_bulk = std::abs(p.r_bulk0) + 0.001;
    valueType r_surf = std::abs(p.r_surf0) + 0.001;

    valueType p_casc = std::abs(p.p_casc0) + 0.001;
    valueType bias = std::abs(p.bias0) + 0.001;
    valueType vg = std::abs(p.vg0) + 0.001;
    valueType diff_v = std::abs(p.diff_v0) + 0.001;
    valueType diff_i = std::abs(p.diff_i0) + 0.001;
    valueType L = std::abs(p.L0) + 0.001;

    // accumulate_weight_derivative(vals, c);

    // Step the three fields first; only the centre slot of each channel block
    // of back_vals is written by forward_one_step_vals.
    valueType back_vals[this->vals_len / 2];
    forward_one_step_vals(vals, back_vals);
    new_v[c] = back_vals[0];                               // should ensure no zero here
    new_v[c + num_items] = back_vals[lap_len_2nd];         // should ensure no zero here
    new_v[c + num_items * 2] = back_vals[lap_len_2nd * 2]; // should ensure no zero here

    // Build a full stencil buffer around the updated centre values: start from
    // a copy of vals, overwrite the three centres, and shift every non-centre
    // field entry by the same amount its centre moved.
    valueType back_vals_full[this->vals_len];
    std::memcpy(back_vals_full, vals, this->vals_len * sizeof(valueType));

    back_vals_full[0] = back_vals[0];
    back_vals_full[lap_len_2nd] = back_vals[lap_len_2nd];
    back_vals_full[lap_len_2nd * 2] = back_vals[lap_len_2nd * 2];
    valueType cv_diff = back_vals[0] - vals[0];
    valueType ci_diff = back_vals[lap_len_2nd] - vals[lap_len_2nd];
    valueType eta_diff = back_vals[lap_len_2nd * 2] - vals[lap_len_2nd * 2];
    for (uint i = 1; i < lap_len_2nd; ++i) {
        back_vals_full[i] = vals[i] + cv_diff;
        back_vals_full[lap_len_2nd + i] = vals[lap_len_2nd + i] + ci_diff;
        back_vals_full[lap_len_2nd * 2 + i] = vals[lap_len_2nd * 2 + i] + eta_diff;
    }
    accumulate_weight_derivative(back_vals_full, c);

    // Step-size-scaled coefficients, evaluated at the updated centre values.
    // kBT is >= 0.001 by construction above, so the divisions are safe.
    valueType dt = 2e-2;
    valueType mv = diff_v * back_vals[0] / kBT; // detect division by zero
    valueType mi = diff_i * back_vals[lap_len_2nd] / kBT; // detect division by zero
    valueType Q = dt * mv;
    valueType P = dt * mi;
    valueType R = dt * (-L) * N;
    valueType cv = back_vals[0];
    valueType ci = back_vals[lap_len_2nd];
    valueType eta = back_vals[lap_len_2nd * 2];

    // Clamp the fields from below so 1 / cv and 1 / ci stay finite.
    if (cv < EPS)
        cv = EPS;

    if (ci < EPS)
        ci = EPS;

    if (eta < EPS)
        eta = EPS;

    // NOTE(review): back_dloss is declared but never used.
    valueType back_dloss[this->vals_len / 2];

    // Same lower clamp for 1 - cv - ci, which is also used as a divisor.
    valueType one_cv_ci = 1 - cv - ci;
    if (one_cv_ci < EPS)
        one_cv_ci = EPS;

    // ---- new dloss value for cv (written to channel 3) ----
    //      compute dloss_dcv_dcv_dcv
    // Copy the first lap_len_1st entries of the dloss_cv stencil and apply the
    // lapw weights to them.
    valueType dloss_dcv[lap_len_1st];
    for (uint i = 0; i < lap_len_1st; ++i) {
        dloss_dcv[i] = vals[this->vals_len / 2 + i];
    }
    valueType QlapDlossDcv = Q * inner_product(dloss_dcv, lapw, lap_len_1st);
    valueType dloss_dcv_dcv_dcv = vals[vals_len / 2] + QlapDlossDcv * ((eta - 1) * (eta - 1) * kBT *
                                                                   (1 / cv + 1 / one_cv_ci) +
                                                                   2 * eta * eta);
    dloss_dcv_dcv_dcv += kappa_v / 2 * Q * Q * inner_product(vals + vals_len / 2, laplapw, lap_len_2nd);
    // Each term is zeroed when the field it is guarded by sits at or outside
    // the [EPS, 1.0) range (this includes values clamped to EPS above).
    if (cv <= EPS || cv >= 1.0) {
        dloss_dcv_dcv_dcv = 0.0;
    }

    //      compute dloss_dci_dci_dcv
    valueType dloss_dci[lap_len_1st];
    for (uint i = 0; i < lap_len_1st; ++i) {
        dloss_dci[i] = vals[this->vals_len / 2 + lap_len_2nd + i];
    }
    valueType PlapDlossDci = P * inner_product(dloss_dci, lapw, lap_len_1st);
    valueType dloss_dci_dci_dcv = PlapDlossDci * ((eta - 1) * (eta - 1) * kBT *
                                                  1 / one_cv_ci);
    if (ci <= EPS || ci >= 1.0) {
        dloss_dci_dci_dcv = 0.0;
    }

    //      compute dloss_deta_deta_dcv
    valueType dloss_deta_deta_dcv = R * vals[this->vals_len / 2 + lap_len_2nd * 2];
    dloss_deta_deta_dcv *= (
            2 * (eta - 1) * (energy_v + kBT * (log_with_mask_single(cv, EPS) - log_with_mask_single(one_cv_ci, EPS))) +
            2 * eta * (2 * (cv - 1)));
    if (eta <= EPS || eta >= 1.0) {
        dloss_deta_deta_dcv = 0.0;
    }

    // final assign
    new_v[c + num_items * 3] = dloss_dcv_dcv_dcv + dloss_dci_dci_dcv + dloss_deta_deta_dcv;

    // ---- new dloss value for ci (written to channel 4) ----
    //      compute dloss_dci_dci_dci
    // PlapDlossDci computed above is reused here.
    // valueType PlapDlossDci = P * inner_product(dloss_dci, lapw, lap_len_1st);
    valueType dloss_dci_dci_dci = vals[vals_len / 2 + lap_len_2nd] + PlapDlossDci * ((eta - 1) * (eta - 1) * kBT *
                                                                                 (1 / ci + 1 / one_cv_ci) +
                                                                                 2 * eta * eta);
    dloss_dci_dci_dci += kappa_i / 2 * P * P * inner_product(vals + vals_len / 2 + lap_len_2nd, laplapw, lap_len_2nd);
    if (ci <= EPS || ci >= 1.0) {
        dloss_dci_dci_dci = 0.0;
    }

    //      compute dloss_dcv_dcv_dci
    valueType dloss_dcv_dcv_dci = QlapDlossDcv * ((eta - 1) * (eta - 1) * kBT *
                                                  1 / one_cv_ci);
    if (cv <= EPS || cv >= 1.0) {
        dloss_dcv_dcv_dci = 0.0;
    }

    //      compute dloss_deta_deta_dci
    valueType dloss_deta_deta_dci = R * vals[this->vals_len / 2 + lap_len_2nd * 2];
    dloss_deta_deta_dci *= (
            2 * (eta - 1) * (energy_i + kBT * (log_with_mask_single(ci, EPS) - log_with_mask_single(one_cv_ci, EPS))) +
            2 * eta * 2 * ci);
    if (eta <= EPS || eta >= 1.0) {
        dloss_deta_deta_dci = 0.0;
    }

    // final assign
    new_v[c + num_items * 4] = dloss_dci_dci_dci + dloss_dcv_dcv_dci + dloss_deta_deta_dci;

    // ---- new dloss value for eta (written to channel 5) ----
    //      compute dloss_dcv_dcv_deta
    valueType dloss_dcv_dcv_deta = QlapDlossDcv * (2 * (eta - 1) * (energy_v + kBT * (log_with_mask_single(cv, EPS) -
                                                                                      log_with_mask_single(one_cv_ci,
                                                                                                           EPS))) +
                                                   2 * eta * 2 * (cv - 1));
    if (cv <= EPS || cv >= 1.0) {
        dloss_dcv_dcv_deta = 0.0;
    }

    //      compute dloss_dci_dci_deta
    valueType dloss_dci_dci_deta = PlapDlossDci * (2 * (eta - 1) * (energy_i + kBT * (log_with_mask_single(ci, EPS) -
                                                                                      log_with_mask_single(one_cv_ci,
                                                                                                           EPS))) +
                                                   2 * eta * 2 * ci);
    if (ci <= EPS || ci >= 1.0) {
        dloss_dci_dci_deta = 0.0;
    }

    //      compute dloss_deta_deta_deta
    valueType dloss_deta[lap_len_1st];
    for (uint i = 0; i < lap_len_1st; ++i) {
        dloss_deta[i] = vals[this->vals_len / 2 + lap_len_2nd * 2 + i];
    }

    // Starts from the incoming dloss_eta centre value, then adds the local
    // term and subtracts the lapw-weighted stencil term.
    valueType dloss_deta_deta_deta = vals[vals_len / 2 + lap_len_2nd * 2];
    dloss_deta_deta_deta += R * vals[vals_len / 2 + lap_len_2nd * 2] * 2 *
                            (energy_v * cv + energy_i * ci + kBT * (cv * log_with_mask_single(cv, EPS) 
                            + ci * log_with_mask_single(ci, EPS) + one_cv_ci * log_with_mask_single(one_cv_ci, EPS)) + (cv - 1) * (cv - 1) + ci * ci);

    dloss_deta_deta_deta -= R * kappa_eta * inner_product(dloss_deta, lapw, lap_len_1st);
    if (eta <= EPS || eta >= 1.0) {
        dloss_deta_deta_deta = 0.0;
    }

    // final assign
    new_v[c + num_items * 5] = dloss_dcv_dcv_deta + dloss_dci_dci_deta + dloss_deta_deta_deta;

    // In debug mode, dump the whole input stencil if any dloss output is NaN.
    if (debug_on) {
        if (isnan(new_v[c + num_items * 3]) || isnan(new_v[c + num_items * 4]) || isnan(new_v[c + num_items * 5])) {
            fflush(stdout);
            cout << "detect nan value" << endl;
            cout << "vals: ";
            for (int i = 0; i < vals_len; ++i)
                cout << vals[i] << ", ";
            cout << endl;
        }
    }
}


/**
 * Compute the next centre values of the three fields from one stencil buffer.
 *
 * vals holds three consecutive blocks of lap_len_2nd entries; the code below
 * refers to them as cv (offset 0), ci (offset lap_len_2nd) and eta (offset
 * lap_len_2nd * 2), with the centre cell in slot 0 of each block.
 *
 * Only three entries of new_v are written: new_v[0], new_v[lap_len_2nd] and
 * new_v[lap_len_2nd * 2]. Each result is forced to 1e-6 when its sign bit is
 * set and capped at 1.0.
 *
 * @param vals  gathered stencil buffer (read only here)
 * @param new_v output buffer with at least lap_len_2nd * 2 + 1 entries
 */
void NanoVoidOneBack::forward_one_step_vals(valueType *vals, valueType *new_v) {
    // Take |parameter| + 0.001 so every coefficient is strictly positive.
    // NOTE(review): r_bulk, r_surf, p_casc, bias and vg are computed but never
    // read in this function.
    valueType energy_v = std::abs(p.energy_v0) + 0.001;
    valueType energy_i = std::abs(p.energy_i0) + 0.001;
    valueType kBT = std::abs(p.kBT0) + 0.001;
    valueType kappa_v = std::abs(p.kappa_v0) + 0.001;
    valueType kappa_i = std::abs(p.kappa_i0) + 0.001;
    valueType kappa_eta = std::abs(p.kappa_eta0) + 0.001;
    valueType r_bulk = std::abs(p.r_bulk0) + 0.001;
    valueType r_surf = std::abs(p.r_surf0) + 0.001;

    valueType p_casc = std::abs(p.p_casc0) + 0.001;
    valueType bias = std::abs(p.bias0) + 0.001;
    valueType vg = std::abs(p.vg0) + 0.001;
    valueType diff_v = std::abs(p.diff_v0) + 0.001;
    valueType diff_i = std::abs(p.diff_i0) + 0.001;
    valueType L = std::abs(p.L0) + 0.001;

    // compute cv, ci
    valueType h_dfs_dcv[lap_len_1st];
    valueType h_dfs_dci[lap_len_1st];

    // construct h_dfs_dcv, h_dfs_dci:
    //   (eta - 1)^2 * (energy + kBT * (log(c) - log(1 - cv - ci)))
    // evaluated at each of the first lap_len_1st stencil slots.
    for (uint i = 0; i < lap_len_1st; i++) {
        h_dfs_dcv[i] = 1.0;
        h_dfs_dci[i] = 1.0;

        //        h_dfs_dcv[i] = (vals[lap_len_2nd * 2 + i] - 1) * (vals[lap_len_2nd * 2 + i] - 1); // (eta-1)**2
        //        h_dfs_dci[i] = h_dfs_dcv[i];

        // Logs are taken of max(value, EPS), see log_with_mask_single.
        valueType log_cv = log_with_mask_single(vals[i], EPS);
        valueType log_ci = log_with_mask_single(vals[lap_len_2nd + i], EPS);
        valueType log_1_cv_ci = log_with_mask_single(1 - vals[i] - vals[i + lap_len_2nd], EPS);

        h_dfs_dcv[i] = h_dfs_dcv[i] * (energy_v + kBT * (log_cv - log_1_cv_ci));
        h_dfs_dci[i] = h_dfs_dci[i] * (energy_i + kBT * (log_ci - log_1_cv_ci));
        // Drop the term entirely where 1 - cv - ci falls below EPS.
        if ((1 - vals[i] - vals[i + lap_len_2nd]) < EPS) {
            h_dfs_dcv[i] = 0;
            h_dfs_dci[i] = 0;
        }

        h_dfs_dcv[i] = h_dfs_dcv[i] * (vals[lap_len_2nd * 2 + i] - 1) * (vals[lap_len_2nd * 2 + i] - 1); // (eta-1)**2
        h_dfs_dci[i] = h_dfs_dci[i] * (vals[lap_len_2nd * 2 + i] - 1) * (vals[lap_len_2nd * 2 + i] - 1);
    }

    // j_dfv_dcv = eta^2 * 2 * (cv - 1), j_dfv_dci = eta^2 * 2 * ci, per slot.
    valueType j_dfv_dcv[lap_len_1st];
    valueType j_dfv_dci[lap_len_1st];

    for (uint i = 0; i < lap_len_1st; i++) {
        j_dfv_dcv[i] = vals[lap_len_2nd * 2 + i] * vals[lap_len_2nd * 2 + i]; // eta**2
        j_dfv_dci[i] = j_dfv_dcv[i];

        j_dfv_dcv[i] = j_dfv_dcv[i] * 2 * (vals[i] - 1);
        j_dfv_dci[i] = j_dfv_dci[i] * 2 * vals[lap_len_2nd + i];
    }

    // Step size and coefficients taken from the current centre values.
    valueType dt = 2e-2;
    valueType mv = diff_v * vals[0] / kBT;
    valueType mi = diff_i * vals[lap_len_2nd] / kBT;

    // cv update: lapw is applied over lap_len_1st slots, laplapw over the full
    // lap_len_2nd stencil (the latter term carries a leading minus sign).
    valueType dt_mv_lap_h_dfs_dcv = dt * mv * inner_product(h_dfs_dcv, lapw, lap_len_1st);
    valueType dt_mv_lap_j_dfv_dcv = dt * mv * inner_product(j_dfv_dcv, lapw, lap_len_1st);
    valueType dt_mv_lap_lap_cv = -dt * mv * inner_product(vals, laplapw, lap_len_2nd);

    new_v[0] = vals[0] - (dt_mv_lap_h_dfs_dcv + dt_mv_lap_j_dfv_dcv + kappa_v * dt_mv_lap_lap_cv);

    // ci update, same structure as cv.
    valueType dt_mi_lap_h_dfs_dci = dt * mi * inner_product(h_dfs_dci, lapw, lap_len_1st);
    valueType dt_mi_lap_j_dfv_dci = dt * mi * inner_product(j_dfv_dci, lapw, lap_len_1st);
    valueType dt_mi_lap_lap_ci = -dt * mi * inner_product(vals + lap_len_2nd, laplapw, lap_len_2nd);

    new_v[0 + lap_len_2nd] =
            vals[lap_len_2nd] - (dt_mi_lap_h_dfs_dci + dt_mi_lap_j_dfv_dci + kappa_i * dt_mi_lap_lap_ci);

    // compute eta
    // fs, evaluated at the centre cell; zeroed where 1 - cv - ci < EPS.
    valueType fs = energy_v * vals[0] + energy_i * vals[lap_len_2nd];
    fs = fs + kBT * (vals[0] * log_with_mask_single(vals[0], EPS));
    fs = fs + kBT * (vals[lap_len_2nd] * log_with_mask_single(vals[lap_len_2nd], EPS));
    fs = fs + kBT * ((1 - vals[0] - vals[lap_len_2nd]) * log_with_mask_single(1 - vals[0] - vals[lap_len_2nd], EPS));
    if ((1 - vals[0] - vals[lap_len_2nd]) < EPS) {
        fs = 0;
    }
    // fv = (cv - 1)^2 + ci^2
    valueType fv = (vals[0] - 1) * (vals[0] - 1) + vals[lap_len_2nd] * vals[lap_len_2nd];

    valueType dF_deta = N * (fs * 2 * (vals[lap_len_2nd * 2] - 1) + fv * 2 * vals[lap_len_2nd * 2] -
                             kappa_eta * inner_product(vals + lap_len_2nd * 2, lapw, lap_len_1st));

    new_v[0 + lap_len_2nd * 2] = vals[lap_len_2nd * 2] - dt * (-L) * dF_deta;

    // Lower bound: any result with its sign bit set (negative values, and also
    // -0.0) is replaced by 1e-6.
    if (std::signbit(new_v[0])) {
        new_v[0] = 1e-6;
    }

    if (std::signbit(new_v[0 + lap_len_2nd])) {
        new_v[0 + lap_len_2nd] = 1e-6;
    }

    if (std::signbit(new_v[0 + lap_len_2nd * 2])) {
        new_v[0 + lap_len_2nd * 2] = 1e-6;
    }

    // Upper bound: cap every result at 1.0.
    if (new_v[0] >= 1.0) {
        new_v[0] = 1.0;
    }

    if (new_v[0 + lap_len_2nd] >= 1.0) {
        new_v[0 + lap_len_2nd] = 1.0;
    }

    if (new_v[0 + lap_len_2nd * 2] >= 1.0) {
        new_v[0 + lap_len_2nd * 2] = 1.0;
    }
}


void NanoVoidOneBack::merge_neighbor_into_n_list(uint item, PNBucket *t) {
    Coordinate2d3c c(0, 0);
    c.from_item(item, Nx, size);

    uint root_item = t->p_list;
    uint root_pd = inv.item2pd[root_item];

    for (uint i = 0; i < lap_len_2nd; i++) {
        Coordinate2d3c cc(c);
        cc.x += dx[i];
        cc.y += dy[i];

        cc.x = max(cc.x, 0);
        cc.x = min(cc.x, Nx - 1);
        cc.y = max(cc.y, 0);
        cc.y = min(cc.y, Ny - 1);

        if (cc.x == c.x && cc.y == c.y)
            continue;

        uint this_item = cc.to_item_c1(Nx, size);
//        if ((t->n_list).find(this_item) != (t->n_list).end())
//            continue;
        // this does not need to be implemented; because we automatically ensure no duplication.

        uint pd = inv.item2pd[this_item];

        if (inv.find_(pd) != root_pd) {
            //            (t->n_list).insert(this_item);
            pnb.n_list_hash.insert_no_duplicate(t->n_list_id, this_item);
        }
    }

//    set<uint>::iterator it = (t->n_list).find(item);
//    if (it != (t->n_list).end())
//        (t->n_list).erase(it);
    uint item_hash;
    uint it = pnb.n_list_hash.find(t->n_list_id, item, item_hash);
    if (it != UINT_NULL)
        pnb.n_list_hash.delete_(it);
//    sort((t->n_list).begin(), (t->n_list).end());
//    vector<uint>::iterator last = unique((t->n_list).begin(), (t->n_list).end());
//    (t->n_list).resize(distance((t->n_list).begin(), last));
//
//    vector<uint>::iterator it = find((t->n_list).begin(), (t->n_list).end(), item);
//    if (it != (t->n_list).end())
//        (t->n_list).erase(it);
}


/**
 * Update a bucket's neighbour list after `item` has left the bucket.
 *
 * Step 1: for every stencil neighbour of `item` that is currently in the
 *         bucket's n_list, scan that neighbour's own stencil; if none of the
 *         cells around it resolves to the bucket any more, remove it from the
 *         n_list.
 * Step 2: if `item` is not already in the n_list and at least one of its
 *         stencil neighbours still resolves to the bucket, insert `item`.
 *
 * "Resolves to the bucket" means inv.find_ of the cell's pd equals the pd of
 * the bucket's p_list item. Coordinates are clamped to the grid throughout.
 *
 * Fix: the inner scan previously indexed dx/dy with the outer loop variable,
 * so all lap_len_2nd iterations inspected the same single cell instead of the
 * whole stencil around the neighbour.
 *
 * @param item cell that was just removed from the bucket
 * @param t    bucket whose n_list (identified by t->n_list_id) is updated
 */
void NanoVoidOneBack::move_out_neighbor_from_n_list(uint item, PNBucket *t) {
    Coordinate2d3c c(0, 0);
    c.from_item(item, Nx, size);

    uint root_item = t->p_list;
    uint root_pd = inv.item2pd[root_item];

    // Step 1: drop n_list entries around `item` that no longer touch the bucket.
    for (uint i = 0; i < lap_len_2nd; i++) {
        Coordinate2d3c cc(c);
        cc.x += dx[i];
        cc.y += dy[i];

        cc.x = max(cc.x, 0);
        cc.x = min(cc.x, Nx - 1);
        cc.y = max(cc.y, 0);
        cc.y = min(cc.y, Ny - 1);

        if (cc.x == c.x && cc.y == c.y)
            continue;

        // Only neighbours that are currently listed need re-checking.
        uint cc_hash;
        uint cc_it = pnb.n_list_hash.find(t->n_list_id, cc.to_item_c1(Nx, size), cc_hash);
        if (cc_it == UINT_NULL)
            continue;

        // Scan the full stencil around this neighbour (index j, not i).
        bool clean_out = true;
        for (uint j = 0; j < lap_len_2nd; j++) {
            Coordinate2d3c c3(cc);
            c3.x += dx[j];
            c3.y += dy[j];

            c3.x = max(c3.x, 0);
            c3.x = min(c3.x, Nx - 1);
            c3.y = max(c3.y, 0);
            c3.y = min(c3.y, Ny - 1);

            if (inv.find_(inv.item2pd[c3.to_item_c1(Nx, size)]) == root_pd) {
                clean_out = false;
                break;
            }
        }
        if (clean_out) {
            pnb.n_list_hash.delete_(cc_it);
        }
    }

    // Step 2: `item` itself may now be a neighbour of the bucket.
    uint item_hash;
    uint item_it = pnb.n_list_hash.find(t->n_list_id, item, item_hash);
    if (item_it != UINT_NULL)
        return;

    bool add_in = false;
    for (uint i = 0; i < lap_len_2nd; ++i) {
        Coordinate2d3c cc(c);
        cc.x += dx[i];
        cc.y += dy[i];

        cc.x = max(cc.x, 0);
        cc.x = min(cc.x, Nx - 1);
        cc.y = max(cc.y, 0);
        cc.y = min(cc.y, Ny - 1);

        if (cc.x == c.x && cc.y == c.y)
            continue;

        if (inv.find_(inv.item2pd[cc.to_item_c1(Nx, size)]) == root_pd) {
            add_in = true;
            break;
        }
    }
    if (add_in) {
        // item_hash was filled in by the failed find() above.
        pnb.n_list_hash.insert_(t->n_list_id, item, item_hash);
    }
}


/**
 * In-place natural log with a lower clamp.
 *
 * Each of the first `len` entries of `mat` is replaced by log(x), where x is
 * the entry itself, or `eps` when the entry is smaller than `eps`.
 *
 * @param mat array modified in place
 * @param eps lower bound applied before taking the log
 * @param len number of entries to process
 */
void NanoVoidOneBack::log_with_mask(valueType *mat, valueType eps, uint len) {
    for (uint idx = 0; idx < len; ++idx) {
        const valueType clamped = (mat[idx] < eps) ? eps : mat[idx];
        mat[idx] = log(clamped);
    }
}


/**
 * Natural log of a single value with a lower clamp.
 *
 * @param p   input value
 * @param eps lower bound; used in place of p when p < eps
 * @return log(p) if p >= eps (or p is not comparable), otherwise log(eps)
 */
valueType NanoVoidOneBack::log_with_mask_single(valueType p, valueType eps) {
    return log((p < eps) ? eps : p);
}


/**
 * Overwrite masked entries of an array with a constant.
 *
 * @param mat  array modified in place
 * @param mask per-entry flags; only entries whose flag equals 1 are changed
 * @param eps  value written into the masked entries
 * @param len  number of entries in both arrays
 */
void NanoVoidOneBack::masked_fill(valueType *mat, int *mask, valueType eps, uint len) {
    for (uint idx = 0; idx < len; ++idx) {
        if (mask[idx] != 1)
            continue;
        mat[idx] = eps;
    }
}

/**
 * Load six flat per-cell arrays into old_v and build the LSH bucket structure.
 *
 * Steps:
 *   1. Copy cv, ci, eta, dloss_cv, dloss_ci, dloss_eta into the six channel
 *      planes of old_v (plane k starts at num_items * k).
 *   2. Create a singleton union-find set for every item index below `size`.
 *   3. Visit the items in chunks of at most 1024 * 1024, shuffled inside each
 *      chunk. For each item: gather its stencil, look its LSH codes up in the
 *      hash tables, join an existing bucket on the first hit or create a new
 *      bucket otherwise, update the bucket's neighbour list, and register the
 *      item's LSH codes for the bucket.
 *
 * Each input array is indexed by c.to_item_c1(Nx, size) for every (x, y) on
 * the Nx by Ny grid.
 */
void NanoVoidOneBack::encode_from_img_torch(valueType *cv, valueType *ci, valueType *eta, valueType *dloss_cv, valueType *dloss_ci, valueType *dloss_eta){
    Coordinate2d3c c(0, 0);
    for (c.x = 0; c.x < Nx; ++c.x) {
        for (c.y = 0; c.y < Ny; ++c.y) {
            uint item_1 = c.to_item_c1(Nx, size);
            old_v[item_1] = cv[item_1];                   // cv
            old_v[item_1 + num_items] = ci[item_1];       // ci
            old_v[item_1 + num_items * 2] = eta[item_1];  // eta
            old_v[item_1 + num_items * 3] = dloss_cv[item_1]; // dloss_dcv
            old_v[item_1 + num_items * 4] = dloss_ci[item_1]; // dloss_dci
            old_v[item_1 + num_items * 5] = dloss_eta[item_1]; // dloss_deta
        }
    }

    if (debug_on) {
        printf("after assign old_v\n");
        fflush(stdout);
    }

    // inv: start with one set per item.
    uint inv_size = (uint) size;
    for (uint i = 0; i < inv_size; i++) {
        inv.makeset(i);
    }

    // Placeholder for a disabled consistency check (empty statement).
    if (debug_on)
        ;
//        inv.check_from_dfslist(num_items);

    // NOTE(review): item_lsh has lshK entries, but the memcpy further down
    // copies sizeof(int) * K bytes -- confirm K == lshK.
    valueType vals[vals_len];
    int item_lsh[lshK];
    uint item_k = 0, item = 0;

    uint handle_once = min(inv_size, (uint)1024*1024);

    // NOTE(review): runtime-sized local array (a compiler extension in C++)
    // with up to 1024 * 1024 entries.
    uint item_vec[handle_once]; //[16777216]; //

    uint usize = inv_size;

    for (uint handle_start = 0; handle_start < usize; handle_start += handle_once) {
        if (debug_on) {
            printf("handle_start=%u\n", handle_start);
            fflush(stdout);
        }

        // Fill the chunk with consecutive item indices, then shuffle it.
        uint num_handled = min(usize - handle_start, handle_once);
        for (item_k = 0; item_k < num_handled; ++ item_k)
            item_vec[item_k] = handle_start + item_k;

        // NOTE(review): random_shuffle was removed from the standard library in C++17.
        random_shuffle(item_vec, item_vec + num_handled);

        for (item_k = 0; item_k < num_handled; ++ item_k) {
            item = item_vec[item_k];
            if (debug_on)
                printf("before grab vals\n");
            grab_vals(item, old_v, vals);
            if (debug_on)
                printf("after grab vals\n");

            uint cl = 0;

            // Probe the hash tables in order; the first table with a matching
            // code decides which bucket the item joins.
            uint pnb_to_add = UINT_NULL;
            for (cl = 0; cl < L; ++ cl) {
                lsh.lsh(vals, cl, item_lsh);

                uint hp_it = hash_t.find(item_lsh, cl);
                if (hp_it != UINT_NULL) {
                    pnb_to_add = hash_t.hp.d[hp_it].pnb;
                    break;
                }
            }

            uint item_pd = inv.item2pd[item];
            PNBucket* pn_it = NULL;

            if (pnb_to_add != UINT_NULL) {
                // Join the existing bucket: union the item's set with the
                // bucket's set and record the item of the resulting root.
                pn_it = &(pnb.d[pnb_to_add]);
                uint ori_pn_plist = pn_it->p_list;
                uint pn_pd = inv.item2pd[pn_it->p_list];
                uint root = inv.union_(item_pd, pn_pd);
                pn_it->p_list = inv.d_item(root);

                // If the bucket's p_list item changed, move the per-table
                // item2hp entries from the old item to the new one.
                if (pn_it->p_list != ori_pn_plist) {
                    for (cl = 0; cl < L; ++ cl) {
                        item2hp_hash.clear(item2hp_id[pn_it->p_list][cl]);
                        item2hp_hash.move_from_id_to_id(item2hp_id[ori_pn_plist][cl], \
                                            item2hp_id[pn_it->p_list][cl]);
                    }
                }
                if (debug_on)
                    printf("before merge_neighbor_into_n_list\n");
                merge_neighbor_into_n_list(item, pn_it);
                if (debug_on)
                    printf("after merge_neighbor_into_n_list\n");
            }
            else {
                // No table matched: open a new bucket with this item as p_list.
                pnb_to_add = pnb.new_elem();
                pn_it = &(pnb.d[pnb_to_add]);
                pn_it->p_list = item;
                //assert((pn_it->n_list).size() == 0);
                if (debug_on)
                    printf("before merge_neighbor_into_n_list\n");
                merge_neighbor_into_n_list(item, pn_it);
                if (debug_on)
                    printf("after merge_neighbor_into_n_list\n");
            }
            // Register the item's code in every table where no entry exists
            // yet or the existing entry points at a different bucket.
            for (cl = 0; cl < L; ++ cl) {
                if (debug_on) {
                    // NOTE(review): cl is a uint but is printed with %d.
                    printf("cl = %d\n", cl);
                }
                lsh.lsh(vals, cl, item_lsh);
                uint hp_it = hash_t.find(item_lsh, cl);
                if (hp_it == UINT_NULL || hash_t.hp.d[hp_it].pnb != pnb_to_add) {
                    uint hp_id = hash_t.hp.new_elem();
                    // Note: this pointer shadows the uint hp_it declared above.
                    HashPointer* hp_it = &(hash_t.hp.d[hp_id]);
                    memcpy(hp_it->lsh_hash_code, item_lsh, sizeof(int)*K);
                    hp_it->hash_code = hash_t.hash_from_lsh(item_lsh);
                    hp_it->pnb = pnb_to_add;
                    hash_t.insert(hp_id, hp_it->hash_code, cl);
                    // bug line
                    item2hp_hash.insert_no_duplicate(item2hp_id[pn_it->p_list][cl], hp_id);
                }
            }
        }
    }
}


/**
 * Load an image and its dloss image into old_v and build the LSH bucket structure.
 *
 * Steps:
 *   1. Copy img[x][y][0..2] and dloss[x][y][0..2] into the six channel planes
 *      of old_v (plane k starts at num_items * k).
 *   2. Create a singleton union-find set for every item index below `size`.
 *   3. Visit the items in chunks of at most 1024 * 1024, shuffled inside each
 *      chunk. For each item: gather its stencil, look its LSH codes up in the
 *      hash tables, join an existing bucket on the first hit or create a new
 *      bucket otherwise, update the bucket's neighbour list, and register the
 *      item's LSH codes for the bucket.
 *
 * @param img   indexed as img[x][y][channel], channels 0..2 = cv, ci, eta
 * @param dloss indexed as dloss[x][y][channel], channels 0..2
 */
void NanoVoidOneBack::encode_from_img(valueType ***img, valueType ***dloss) {

    Coordinate2d3c c(0, 0);
    for (c.x = 0; c.x < Nx; ++c.x) {
        for (c.y = 0; c.y < Ny; ++c.y) {
            uint item_1 = c.to_item_c1(Nx, size);
            old_v[item_1] = img[c.x][c.y][0];                   // cv
            old_v[item_1 + num_items] = img[c.x][c.y][1];       // ci
            old_v[item_1 + num_items * 2] = img[c.x][c.y][2];   // eta
            old_v[item_1 + num_items * 3] = dloss[c.x][c.y][0]; // dloss_dcv
            old_v[item_1 + num_items * 4] = dloss[c.x][c.y][1]; // dloss_dci
            old_v[item_1 + num_items * 5] = dloss[c.x][c.y][2]; // dloss_deta
        }
    }

    if (debug_on) {
        printf("after assign old_v\n");
        fflush(stdout);
    }

    // inv: start with one set per item.
    uint inv_size = (uint) size;
    for (uint i = 0; i < inv_size; i++) {
        inv.makeset(i);
    }

    // Placeholder for a disabled consistency check (empty statement).
    if (debug_on)
        ;
//        inv.check_from_dfslist(num_items);

    // NOTE(review): item_lsh has lshK entries, but the memcpy further down
    // copies sizeof(int) * K bytes -- confirm K == lshK.
    valueType vals[vals_len];
    int item_lsh[lshK];
    uint item_k = 0, item = 0;

    uint handle_once = min(inv_size, (uint)1024*1024);

    // NOTE(review): runtime-sized local array (a compiler extension in C++)
    // with up to 1024 * 1024 entries.
    uint item_vec[handle_once]; //[16777216]; //

    uint usize = inv_size;

    for (uint handle_start = 0; handle_start < usize; handle_start += handle_once) {
        if (debug_on) {
            printf("handle_start=%u\n", handle_start);
            fflush(stdout);
        }

        // Fill the chunk with consecutive item indices, then shuffle it.
        uint num_handled = min(usize - handle_start, handle_once);
        for (item_k = 0; item_k < num_handled; ++ item_k)
            item_vec[item_k] = handle_start + item_k;

        // NOTE(review): random_shuffle was removed from the standard library in C++17.
        random_shuffle(item_vec, item_vec + num_handled);

        for (item_k = 0; item_k < num_handled; ++ item_k) {
            item = item_vec[item_k];
            if (debug_on)
                printf("before grab vals\n");
            grab_vals(item, old_v, vals);
            if (debug_on)
                printf("after grab vals\n");

            uint cl = 0;

            // Probe the hash tables in order; the first table with a matching
            // code decides which bucket the item joins.
            uint pnb_to_add = UINT_NULL;
            for (cl = 0; cl < L; ++ cl) {
                lsh.lsh(vals, cl, item_lsh);

                uint hp_it = hash_t.find(item_lsh, cl);
                if (hp_it != UINT_NULL) {
                    pnb_to_add = hash_t.hp.d[hp_it].pnb;
                    break;
                }
            }

            uint item_pd = inv.item2pd[item];
            PNBucket* pn_it = NULL;

            if (pnb_to_add != UINT_NULL) {
                // Join the existing bucket: union the item's set with the
                // bucket's set and record the item of the resulting root.
                pn_it = &(pnb.d[pnb_to_add]);
                uint ori_pn_plist = pn_it->p_list;
                uint pn_pd = inv.item2pd[pn_it->p_list];
                uint root = inv.union_(item_pd, pn_pd);
                pn_it->p_list = inv.d_item(root);

                // If the bucket's p_list item changed, move the per-table
                // item2hp entries from the old item to the new one.
                if (pn_it->p_list != ori_pn_plist) {
                    for (cl = 0; cl < L; ++ cl) {
                        item2hp_hash.clear(item2hp_id[pn_it->p_list][cl]);
                        item2hp_hash.move_from_id_to_id(item2hp_id[ori_pn_plist][cl], \
                                            item2hp_id[pn_it->p_list][cl]);
                    }
                }
                if (debug_on)
                    printf("before merge_neighbor_into_n_list\n");
                merge_neighbor_into_n_list(item, pn_it);
                if (debug_on)
                    printf("after merge_neighbor_into_n_list\n");
            }
            else {
                // No table matched: open a new bucket with this item as p_list.
                pnb_to_add = pnb.new_elem();
                pn_it = &(pnb.d[pnb_to_add]);
                pn_it->p_list = item;
                //assert((pn_it->n_list).size() == 0);
                if (debug_on)
                    printf("before merge_neighbor_into_n_list\n");
                merge_neighbor_into_n_list(item, pn_it);
                if (debug_on)
                    printf("after merge_neighbor_into_n_list\n");
            }
            // Register the item's code in every table where no entry exists
            // yet or the existing entry points at a different bucket.
            for (cl = 0; cl < L; ++ cl) {
                if (debug_on) {
                    // NOTE(review): cl is a uint but is printed with %d.
                    printf("cl = %d\n", cl);
                }
                lsh.lsh(vals, cl, item_lsh);
                uint hp_it = hash_t.find(item_lsh, cl);
                if (hp_it == UINT_NULL || hash_t.hp.d[hp_it].pnb != pnb_to_add) {
                    uint hp_id = hash_t.hp.new_elem();
                    // Note: this pointer shadows the uint hp_it declared above.
                    HashPointer* hp_it = &(hash_t.hp.d[hp_id]);
                    memcpy(hp_it->lsh_hash_code, item_lsh, sizeof(int)*K);
                    hp_it->hash_code = hash_t.hash_from_lsh(item_lsh);
                    hp_it->pnb = pnb_to_add;
                    hash_t.insert(hp_id, hp_it->hash_code, cl);
                    // bug line
                    item2hp_hash.insert_no_duplicate(item2hp_id[pn_it->p_list][cl], hp_id);
                }
            }
        }
    }

    // A commented-out "v2" variant of this encoding loop used to follow here.
    // It differed from the code above in three ways: it shuffled all
    // num_items indices in a single pass instead of in chunks, it kept the
    // per-item hash pointers in set-like item2hp[item][cl] containers instead
    // of item2hp_hash, and it inserted a new hash pointer only when
    // hash_t.find() returned UINT_NULL (it did not compare the found entry's
    // bucket with pnb_to_add).
}


// Expands the bucketed state in old_v into a dense image indexed as
// img[x][y][channel], with x in [0, Nx), y in [0, Ny) and n_channels values
// per pixel. Each pixel takes the values stored for the root item of the
// bucket it belongs to.
//
// Every level of the result is allocated with new[] and is not retained by
// this object, so releasing it is left to the caller.
valueType ***NanoVoidOneBack::decode_to_img() {
    valueType ***img = new valueType **[Nx];

    Coordinate2d3c pos(0, 0);
    for (pos.x = 0; pos.x < Nx; ++pos.x) {
        img[pos.x] = new valueType *[Ny];
        for (pos.y = 0; pos.y < Ny; ++pos.y) {
            // pixel -> flat item -> bucket entry -> bucket root -> root item
            const uint flat = pos.to_item_c1(Nx, num_items);
            const uint root_item = inv.d_item(inv.find_(inv.item2pd[flat]));

            // Channel ch of an item is stored num_items * ch entries further on.
            valueType *pixel = new valueType[n_channels];
            for (uint ch = 0; ch < n_channels; ++ch) {
                pixel[ch] = old_v[root_item + num_items * ch];
            }
            img[pos.x][pos.y] = pixel;
        }
    }
    return img;
}

// Expands the bucketed state in old_v into n_channels dense planes of
// Nx * Ny values each: result[channel][item]. In this file the planes are,
// in order, cv, ci, eta and the three loss gradients w.r.t. them.
// Each item takes the values stored for the root item of its bucket.
//
// The pointer table and every plane are allocated with new[] and are not
// retained by this object, so releasing them is left to the caller.
valueType** NanoVoidOneBack::decode_to_img_torch() {
    valueType** mtx = new valueType*[n_channels];
    for (uint ch = 0; ch < n_channels; ++ch) {
        mtx[ch] = new valueType[Nx * Ny];
    }

    Coordinate2d3c c(0, 0);
    for (c.x = 0; c.x < Nx; ++c.x) {
        // The y coordinate runs over Ny (not Nx) so that non-square grids are
        // covered completely and never indexed past their last column.
        for (c.y = 0; c.y < Ny; ++c.y) {
            // NOTE(review): the stride passed here is Nx while grab_vals lays
            // items out as x * Ny + y -- confirm against Coordinate2d3c.
            const uint item = c.to_item_c1(Nx, size);
            const uint item_pd = inv.item2pd[item];
            const uint root = inv.find_(item_pd);
            const uint root_item = inv.d_item(root);

            // Channel ch of an item is stored num_items * ch entries further on.
            for (uint ch = 0; ch < n_channels; ++ch) {
                mtx[ch][item] = old_v[root_item + num_items * ch];
            }
        }
    }
    return mtx;
}


// Adds one item's contribution to the loss gradient w.r.t. the model
// parameters, accumulating into `dp`.
//
// vals : stencil values for the item, channel-major in runs of lap_len_2nd
//        entries; entry 0 of each run is the item itself. Channels 0..2 are
//        cv, ci and eta. Channels 3..5 are the incoming loss gradients w.r.t.
//        cv, ci and eta (they are named dloss_* in decode_to_img_torch).
// c    : item index, used only to look up the size of the bucket the item
//        stands for; every contribution is scaled by that size.
//
// For each parameter the three sensitivities (d cv, d ci, d eta of one time
// step w.r.t. that parameter) are computed first, then combined with the
// incoming gradients at the end. The parameters enter the model as
// |param| + 0.001, hence the sign flip whenever the raw parameter is <= 0.
void NanoVoidOneBack::accumulate_weight_derivative(valueType *vals, uint c) {
    // Effective (strictly positive) parameter values.
    const valueType energy_v = std::abs(p.energy_v0) + 0.001;
    const valueType energy_i = std::abs(p.energy_i0) + 0.001;
    const valueType kBT = std::abs(p.kBT0) + 0.001;
    const valueType kappa_v = std::abs(p.kappa_v0) + 0.001;
    const valueType kappa_i = std::abs(p.kappa_i0) + 0.001;
    const valueType kappa_eta = std::abs(p.kappa_eta0) + 0.001;
    const valueType diff_v = std::abs(p.diff_v0) + 0.001;
    const valueType diff_i = std::abs(p.diff_i0) + 0.001;
    const valueType L = std::abs(p.L0) + 0.001;

    const valueType dt = 2e-2;

    // Field values at the item itself.
    const valueType cv = vals[0];
    const valueType ci = vals[lap_len_2nd];
    const valueType eta = vals[lap_len_2nd * 2];

    // Mobilities: diff * concentration / kBT.
    const valueType mv = diff_v * cv / kBT;
    const valueType mi = diff_i * ci / kBT;

    // Incoming loss gradients, pre-scaled by the size of this item's bucket.
    // NOTE(review): d_size is given the item's own entry, not the result of
    // find_ -- confirm it resolves the bucket root itself.
    const uint bucket_size = inv.d_size(inv.item2pd[c]);
    const valueType w_cv = bucket_size * vals[lap_len_2nd * 3];
    const valueType w_ci = bucket_size * vals[lap_len_2nd * 4];
    const valueType w_eta = bucket_size * vals[lap_len_2nd * 5];

    // 1 - cv - ci, kept away from zero so its logarithm stays usable.
    valueType one_cv_ci = 1 - cv - ci;
    if (one_cv_ci < EPS)
        one_cv_ci = EPS;

    // Loop-invariant logarithms and stencil sums, computed once.
    const valueType log_cv = log_with_mask_single(cv, EPS);
    const valueType log_ci = log_with_mask_single(ci, EPS);
    const valueType log_rest = log_with_mask_single(one_cv_ci, EPS);
    const valueType mix_entropy = cv * log_cv + ci * log_ci + one_cv_ci * log_rest;

    const valueType laplap_cv = inner_product(vals, laplapw, lap_len_2nd);
    const valueType laplap_ci = inner_product(vals + lap_len_2nd, laplapw, lap_len_2nd);
    const valueType lap_eta = inner_product(vals + lap_len_2nd * 2, lapw, lap_len_1st);

    // (eta - 1)^2 on the first-order stencil, and its Laplacian.
    valueType eta_1_sq[lap_len_1st];
    for (uint i = 0; i < lap_len_1st; ++i) {
        eta_1_sq[i] = (vals[lap_len_2nd * 2 + i] - 1) * (vals[lap_len_2nd * 2 + i] - 1);
    }
    const valueType lap_eta_1_sq = inner_product(eta_1_sq, lapw, lap_len_1st);

    // Prefactor shared by the eta sensitivities w.r.t. energy_v, energy_i, kBT.
    const valueType deta_bulk = dt * (-L) * N * 2 * (eta - 1);

    // energy v
    valueType dcv_dev = dt * mv * lap_eta_1_sq;
    valueType dci_dev = 0.0;
    valueType deta_dev = deta_bulk * cv;
    if (p.energy_v0 <= 0.0) {
        dcv_dev = - dcv_dev;
        dci_dev = - dci_dev;
        deta_dev = - deta_dev;
    }

    // energy i
    valueType dcv_dei = 0.0;
    valueType dci_dei = dt * mi * lap_eta_1_sq;
    valueType deta_dei = deta_bulk * ci;
    if (p.energy_i0 <= 0) {
        dcv_dei = - dcv_dei;
        dci_dei = - dci_dei;
        deta_dei = - deta_dei;
    }

    // kBT
    valueType eta_1_sq_cv[lap_len_1st];
    valueType eta_1_sq_ci[lap_len_1st];
    for (uint i = 0; i < lap_len_1st; ++i) {
        eta_1_sq_cv[i] = eta_1_sq[i] * (log_cv - log_rest);
        eta_1_sq_ci[i] = eta_1_sq[i] * (log_ci - log_rest);
    }
    valueType dcv_dkBT = dt * mv * inner_product(eta_1_sq_cv, lapw, lap_len_1st);
    // The ci equation is driven by the ci mobility (mi), like every other
    // ci sensitivity in this function.
    valueType dci_dkBT = dt * mi * inner_product(eta_1_sq_ci, lapw, lap_len_1st);
    valueType deta_dkBT = deta_bulk * mix_entropy;
    if (p.kBT0 <= 0) {
        dcv_dkBT = - dcv_dkBT;
        dci_dkBT = - dci_dkBT;
        deta_dkBT = - deta_dkBT;
    }

    // kappa v
    valueType dcv_dkappa_v = -dt * mv * laplap_cv;
    valueType dci_dkappa_v = 0.0;
    valueType deta_dkappa_v = 0.0;
    if (p.kappa_v0 <= 0) {
        dcv_dkappa_v = - dcv_dkappa_v;
        dci_dkappa_v = - dci_dkappa_v;
        deta_dkappa_v = - deta_dkappa_v;
    }

    // kappa i
    valueType dcv_dkappa_i = 0.0;
    valueType dci_dkappa_i = -dt * mi * laplap_ci;
    valueType deta_dkappa_i = 0.0;
    if (p.kappa_i0 <= 0) {
        dcv_dkappa_i = - dcv_dkappa_i;
        dci_dkappa_i = - dci_dkappa_i;
        deta_dkappa_i = - deta_dkappa_i;
    }

    // kappa eta
    valueType dcv_dkappa_eta = 0.0;
    valueType dci_dkappa_eta = 0.0;
    valueType deta_dkappa_eta = dt * L * N * lap_eta;
    if (p.kappa_eta0 <= 0) {
        dcv_dkappa_eta = - dcv_dkappa_eta;
        dci_dkappa_eta = - dci_dkappa_eta;
        deta_dkappa_eta = - deta_dkappa_eta;
    }

    // diff v
    // NOTE(review): the second summand uses ci * ci although the array name
    // suggests an eta^2 factor -- confirm against the model equations.
    valueType eta_1_sq_eta_sq_cv[lap_len_1st];
    for (uint i = 0; i < lap_len_1st; ++i) {
        eta_1_sq_eta_sq_cv[i] = eta_1_sq[i] * (energy_v + kBT * (log_cv - log_rest)) +
                                ci * ci * 2 * (cv - 1);
    }
    valueType dcv_ddiff_v = dt * cv / kBT * (inner_product(eta_1_sq_eta_sq_cv, lapw, lap_len_1st) -
                                             kappa_v * laplap_cv);
    valueType dci_ddiff_v = 0.0;
    valueType deta_ddiff_v = 0.0;
    if (p.diff_v0 <= 0) {
        dcv_ddiff_v = - dcv_ddiff_v;
        dci_ddiff_v = - dci_ddiff_v;
        deta_ddiff_v = - deta_ddiff_v;
    }

    // diff i
    // NOTE(review): same remark as for diff v regarding the ci * ci factor.
    valueType dcv_ddiff_i = 0.0;
    valueType eta_1_sq_eta_sq_ci[lap_len_1st];
    for (uint i = 0; i < lap_len_1st; ++i) {
        eta_1_sq_eta_sq_ci[i] = eta_1_sq[i] * (energy_i + kBT * (log_ci - log_rest)) +
                                ci * ci * 2 * ci;
    }
    valueType dci_ddiff_i = dt * ci / kBT * (inner_product(eta_1_sq_eta_sq_ci, lapw, lap_len_1st) -
                                             kappa_i * laplap_ci);
    valueType deta_ddiff_i = 0.0;
    if (p.diff_i0 <= 0) {
        dcv_ddiff_i = - dcv_ddiff_i;
        dci_ddiff_i = - dci_ddiff_i;
        deta_ddiff_i = - deta_ddiff_i;
    }

    // L
    valueType dcv_dL = 0.0;
    valueType dci_dL = 0.0;
    valueType fs = energy_v * cv + energy_i * ci + kBT * mix_entropy;
    // NOTE(review): one_cv_ci was clamped to at least EPS above, so this
    // guard looks unreachable -- confirm the intended condition.
    if (one_cv_ci < EPS) {
        fs = 0;
    }
    const valueType fv = (cv - 1) * (cv - 1) + ci * ci;
    valueType deta_dL = -dt * N * (fs * 2 * (eta - 1) + fv * 2 * eta - kappa_eta * lap_eta);
    if (p.L0 <= 0) {
        dcv_dL = -dcv_dL;
        dci_dL = -dci_dL;
        deta_dL = -deta_dL;
    }

    // Chain rule: gradient(param) += bucket_size * dloss_field * d field / d param.
    // A field contributes only while its value lies strictly inside (EPS, 1).
    if (!(cv <= EPS || cv >= 1.0)) {
        // cv terms
        dp.energy_v0 += w_cv * dcv_dev;
        dp.energy_i0 += w_cv * dcv_dei;
        dp.kBT0 += w_cv * dcv_dkBT;
        dp.kappa_v0 += w_cv * dcv_dkappa_v;
        dp.kappa_i0 += w_cv * dcv_dkappa_i;
        dp.kappa_eta0 += w_cv * dcv_dkappa_eta;
        dp.diff_v0 += w_cv * dcv_ddiff_v;
        dp.diff_i0 += w_cv * dcv_ddiff_i;
        dp.L0 += w_cv * dcv_dL;
    }
    if (!(ci <= EPS || ci >= 1.0)) {
        // ci terms -- energy_v pairs with the ci sensitivity (dci_dev).
        dp.energy_v0 += w_ci * dci_dev;
        dp.energy_i0 += w_ci * dci_dei;
        dp.kBT0 += w_ci * dci_dkBT;
        dp.kappa_v0 += w_ci * dci_dkappa_v;
        dp.kappa_i0 += w_ci * dci_dkappa_i;
        dp.kappa_eta0 += w_ci * dci_dkappa_eta;
        dp.diff_v0 += w_ci * dci_ddiff_v;
        dp.diff_i0 += w_ci * dci_ddiff_i;
        dp.L0 += w_ci * dci_dL;
    }
    if (!(eta <= EPS || eta >= 1.0)) {
        // eta terms -- energy_v pairs with the eta sensitivity (deta_dev).
        dp.energy_v0 += w_eta * deta_dev;
        dp.energy_i0 += w_eta * deta_dei;
        dp.kBT0 += w_eta * deta_dkBT;
        dp.kappa_v0 += w_eta * deta_dkappa_v;
        dp.kappa_i0 += w_eta * deta_dkappa_i;
        dp.kappa_eta0 += w_eta * deta_dkappa_eta;
        dp.diff_v0 += w_eta * deta_ddiff_v;
        dp.diff_i0 += w_eta * deta_ddiff_i;
        dp.L0 += w_eta * deta_dL;
    }
}


// Writes the accumulated parameter gradient `dp` to cout, one
// "name: value" line per parameter, after a heading line.
void NanoVoidOneBack::print_derivative() {
    // Flush C stdio first; presumably so earlier printf output appears
    // before the stream output below.
    fflush(stdout);

    cout << "derivative of weight: " << endl
         << "energy_v: " << dp.energy_v0 << endl
         << "energy_i: " << dp.energy_i0 << endl
         << "kBT: " << dp.kBT0 << endl
         << "kappa_v: " << dp.kappa_v0 << endl
         << "kappa_i: " << dp.kappa_i0 << endl
         << "kappa_eta: " << dp.kappa_eta0 << endl
         << "diff_v: " << dp.diff_v0 << endl
         << "diff_i: " << dp.diff_i0 << endl
         << "L: " << dp.L0 << endl;
}


// Copies the values of item c_old in old_v to item c_new in new_v, for every
// channel. Both indices are first reduced modulo `size`; channel pg of an
// item is stored pg * size entries after channel 0.
void NanoVoidOneBack::assign_vals(valueType *old_v, uint c_old, valueType *new_v, uint c_new) {
    const uint dst = c_new % size;
    const uint src = c_old % size;

    uint channel_base = 0;
    for (uint pg = 0; pg < n_channels; ++pg, channel_base += size) {
        new_v[channel_base + dst] = old_v[channel_base + src];
    }
}

// Returns the accumulated parameter gradient `dp` as a freshly allocated
// array of 9 values, in the order:
//   energy_v, energy_i, kBT, kappa_v, kappa_i, kappa_eta, diff_v, diff_i, L.
// The array is allocated with new[] and not retained here, so releasing it
// is left to the caller.
valueType* NanoVoidOneBack::decode_derivative() {
    valueType *grad = new valueType[9];

    valueType *slot = grad;
    *slot++ = dp.energy_v0;
    *slot++ = dp.energy_i0;
    *slot++ = dp.kBT0;
    *slot++ = dp.kappa_v0;
    *slot++ = dp.kappa_i0;
    *slot++ = dp.kappa_eta0;
    *slot++ = dp.diff_v0;
    *slot++ = dp.diff_i0;
    *slot = dp.L0;

    return grad;
}


// Stencil offsets (dx[i], dy[i]) used by grab_vals, 13 entries in total:
//   index 0      : the item itself
//   indices 1-4  : the four axis neighbours at distance 1
//   indices 5-8  : the four diagonal neighbours
//   indices 9-12 : the four axis neighbours at distance 2
const int NanoVoidOneBack::dx[] = {0, 1, 0, -1, 0, 1, -1, 1, -1, 2, 0, -2, 0};
const int NanoVoidOneBack::dy[] = {0, 0, 1, 0, -1, 1, 1, -1, -1, 0, 2, 0, -2};
// Weights paired entry-by-entry with all 13 offsets above (20 for the centre,
// -8 for distance-1 neighbours, 2 for diagonals, 1 for distance-2 neighbours).
// These match the usual 13-point stencil of the squared Laplacian.
const valueType NanoVoidOneBack::laplapw[] = {20, -8, -8, -8, -8, 2, 2, 2, 2, 1, 1, 1, 1};
// Weights paired with the first five offsets above: the 5-point Laplacian
// (-4 for the centre, 1 for each distance-1 axis neighbour).
const valueType NanoVoidOneBack::lapw[] = {-4, 1, 1, 1, 1};